# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorFlow-based trainer for graph embedding with GloVe model."""

# pylint: disable=use-symbolic-message-instead
# pylint: disable=missing-docstring
# pylint: disable=invalid-name
# pylint: disable=C6120,C6113

from __future__ import print_function

import copy
import os
import struct
import glove_util
import numpy as np
import tensorflow as tf


# Named summary statistics of a 2-D weight matrix, printed by the diagnostics
# helpers (rows index vocabulary entries or input features).
METRICS_TO_PRINT = dict(
    # Mean of the squared entries.
    mat_msq=lambda m: np.mean(m**2.0),
    # Mean of all entries.
    mat_avg=np.mean,
    # Largest absolute row mean.
    mat_max_avg=lambda m: np.max(np.abs(np.mean(m, axis=1))),
    # Largest per-row mean of squares.
    mat_max_msq=lambda m: np.max(np.mean(m**2.0, axis=1)),
    # Single entries [0, 0] and [1, 0], as Python floats.
    m00=lambda m: float(m[0, 0]),
    m10=lambda m: float(m[1, 0]),
    # Mean over rows of the squared L2 norm of each row.
    avg_mag=lambda m: np.mean(np.sum(m**2.0, axis=1)),
)


class Discriminator(object):
  """Small MLP adversary: one ReLU hidden layer followed by a sigmoid output.

  Its variables live under the 'discriminator' variable scope, so a second
  call to `net` with reuse=True shares the weights created by the first call.
  """

  def __init__(self, n_input, n_hidden):
    # Width of each input row and of the hidden layer, respectively.
    self.n_input = n_input
    self.n_hidden = n_hidden

  def net(self, inputs, reuse=False):
    """Returns sigmoid(relu(inputs @ W1 + b1) @ W2 + b2).

    Args:
      inputs: float tensor whose rows have n_input entries.
      reuse: pass True on every call after the first to share variables.

    Returns:
      A tensor with a single sigmoid output per input row.
    """
    weight_init = tf.random_normal_initializer(stddev=0.01)
    bias_init = tf.constant_initializer(0)
    with tf.variable_scope('discriminator') as scope:
      if reuse:
        scope.reuse_variables()
      hidden_w = tf.get_variable(
          'd_W1', [self.n_input, self.n_hidden], initializer=weight_init)
      hidden_b = tf.get_variable(
          'd_b1', [self.n_hidden], initializer=bias_init)
      out_w = tf.get_variable(
          'd_W2', [self.n_hidden, 1], initializer=weight_init)
      out_b = tf.get_variable('d_b2', [1], initializer=bias_init)

    hidden = tf.nn.relu(tf.matmul(inputs, hidden_w) + hidden_b)
    return tf.nn.sigmoid(tf.matmul(hidden, out_w) + out_b)


class GloVeModelTf(object):
  """A class to do GloVe training given cooccurrences and sorted vocab file.

  Mimics behavior and performance of the original c implementation.
  Requires files generated by code at https://github.com/stanfordnlp/GloVe.
  ***New: can also take a raw list of sentences and generate cooccurrences.

  Args:
    vector_size: desired size of the model
    vocab_filename: name of the vocab file for building the model
  """

  def _weight_initializer(self,
                          weight_name,
                          init_width,
                          rowdim,
                          coldim,
                          load_name=None):
    """Creates a float32 weight variable, loaded from disk or drawn uniformly.

    If self._init_weight_dir is set, values are read with np.loadtxt from
    '<init_weight_dir>/<load_name>.txt' and reshaped to [rowdim, coldim].
    Otherwise they are drawn uniformly from [-init_width, init_width].

    Args:
      weight_name: name for the weights in the graph
      init_width: width of the uniform initializer
      rowdim: row dimension size of the weights
      coldim: col dimension size of the weights
      load_name: set only if load name is different from variable name

    Returns:
      a tf variable initialized with weight_name
    """
    if not load_name:
      load_name = weight_name
    if self._init_weight_dir:
      weight_path = os.path.join(self._init_weight_dir, '%s.txt' % load_name)
      with open(weight_path, 'rb') as f:
        # Shape is passed positionally: the `newshape=` keyword of np.reshape
        # is deprecated since NumPy 2.1.
        weights = np.reshape(np.loadtxt(f), (rowdim, coldim))
    else:
      weights = tf.random_uniform([rowdim, coldim], -init_width, init_width)
    return tf.Variable(weights, name=weight_name, dtype=tf.float32)

  def __init__(self,
               vector_size,
               vocab_filename=None,
               covariate_size=0,
               random_seed=12345,
               init_weight_dir=None,
               random_walks=None,
               covariate_data=None,
               window_size=5):
    """Initializes the data reading and model variables.

    Args:
      vector_size: size of the word vectors.
      vocab_filename: filename for getting word tokens.
      covariate_size: size of the covariate embedding dimension
      random_seed: seed the initialization generator
      init_weight_dir: directory to pull initial weights from. defaults to a
        uniform initializer if none.
      random_walks: a list of tokenized sentences
      covariate_data: a keyed list of float lists, where each key identifies a
        token in the corpus, and each float list is a row of covariate data
      window_size: window size to use for cooccurrence counting, if needed
    Returns: (none)
    """
    print('setting up basic stuff...')
    # Get word tokens
    self._vector_size = vector_size
    self._covariate_size = covariate_size
    self._tokens = []
    self._vocab_index_lookup = None
    if vocab_filename:
      with open(vocab_filename, 'r') as f:
        for line in f:
          self._tokens.append(line.split()[0])
      self._vocab_index_lookup = dict(
          zip(self._tokens, list(range(len(self._tokens)))))
    self._cooccurrences = None
    self._cooccurrence_dict = None
    print('loading or computing co-occurrences...')
    if random_walks:
      (self._cooccurrences, self._tokens, self._vocab_index_lookup,
       self._cooccurrence_dict) = glove_util.count_cooccurrences(
           random_walks, window_size, self._vocab_index_lookup)
    self._vocab_size = len(self._tokens)

    # Get covariate data
    print('setting other placeholders...')
    if covariate_data is not None:
      self._covariate_data = np.array([covariate_data[t] for t in self._tokens])

    # Placeholders for parameter tensors and other trackers
    io_dict = {'input': None, 'outpt': None}
    self._word = copy.deepcopy(io_dict)
    self._bias = copy.deepcopy(io_dict)
    self._iter = 0
    self._sum_cost = 0
    self._sum_adv_cost_g = 0
    self._sum_adv_cost_d = 0
    self._random_seed = random_seed
    self._init_weight_dir = init_weight_dir

    # Pointers to variables needed for covariate model
    self._cvrt = copy.deepcopy(io_dict)
    self._cvrt_transformation = copy.deepcopy(io_dict)

    # Initialize the cooccurrence read format
    self._cooccurrence_fmt = 'iid'
    self._cooccurrence_fmt_length = struct.calcsize(self._cooccurrence_fmt)
    self._struct_unpack = struct.Struct(self._cooccurrence_fmt).unpack_from

  def _compute_loss_weight(self, y):
    """Computes the loss weighting function as defined in the original paper.

    Args:
      y: the raw (un-logged) cooccurrence score

    Returns:
      weighted loss
    """
    return 1.0 if y > self._xmax else pow(y / self._xmax, self._alpha)

  def _cap_logits(self, logits, cap):
    """Clamps `logits` elementwise to the closed interval [-cap, cap]."""
    upper_capped = tf.math.minimum(logits, cap)
    capped = tf.math.maximum(upper_capped, -cap)
    return capped

  def _compute_w2v_loss(self, input_indxs):
    """Defines the part of the graph that computes the w2v-style loss.

    Combines a positive log-sigmoid term on self._est_score with a
    negative-sampling term over self._neg_samples_tensor. Both are scaled by
    the per-pair self._scores.

    Args:
      input_indxs: int32 or int64 tensor for input labels

    Returns:
      w2v training loss (scalar tensor). Also sets self._diff to the negated
      per-pair positive loss values.
    """
    # If using word2vec, we need negative samples
    print('registering w2v loss in the graph')
    # Dot product of each input word vector with the outpt vectors of its
    # negative samples: [batch, neg_samples].
    word_neg_sample_loss_dot_products = tf.reduce_sum(
        tf.multiply(
            tf.nn.embedding_lookup(self._word['outpt'],
                                   self._neg_samples_tensor),
            tf.expand_dims(self._input_word_vecs, 1)),
        axis=2)
    if self._covariate_size > 0:
      cvrt_neg_sample_loss_dot_products = tf.reduce_sum(
          tf.multiply(
              tf.nn.embedding_lookup(self._cvrt['outpt'],
                                     self._neg_samples_tensor),
              tf.expand_dims(self._input_cvrt_vecs, 1)),
          axis=2)
    else:
      cvrt_neg_sample_loss_dot_products = tf.fill(
          tf.shape(word_neg_sample_loss_dot_products), 0.0)
    # Bias tables have one column, so the sum is [batch, neg_samples, 1].
    # Squeeze only that trailing axis. A bare squeeze would also drop the batch
    # or neg-sample axis when it has size 1, and the addition below would then
    # broadcast to the wrong shape.
    bias_neg_sample_loss_sum = tf.squeeze(
        tf.math.add(
            tf.nn.embedding_lookup(self._bias['outpt'],
                                   self._neg_samples_tensor),
            tf.expand_dims(
                tf.nn.embedding_lookup(self._bias['input'], input_indxs), 1)),
        axis=2)
    neg_sample_loss_logits = (
        word_neg_sample_loss_dot_products + cvrt_neg_sample_loss_dot_products +
        bias_neg_sample_loss_sum)
    neg_sample_loss_logits = self._cap_logits(neg_sample_loss_logits,
                                              self._w2v_logit_max)
    # log_sigmoid(x) equals log(sigmoid(x)) but does not underflow to -inf for
    # very negative x (possible when w2v_logit_max is large).
    neg_sample_loss_values = self._scores * tf.reduce_sum(
        tf.math.log_sigmoid(-neg_sample_loss_logits),
        axis=1) / (2.0 * self._w2v_neg_sample_scale)
    pos_loss_logits = self._cap_logits(self._est_score, self._w2v_logit_max)
    pos_loss_values = self._scores * tf.math.log_sigmoid(pos_loss_logits) / 2.0
    self._diff = -1.0 * pos_loss_values
    return -1.0 * tf.reduce_sum(pos_loss_values + neg_sample_loss_values)

  def _forward(self, input_indxs, outpt_indxs, scores, weights):
    """Build the graph for the forward pass.

    Creates the word, bias and (if covariate_size > 0) covariate-transformation
    parameters and looks up the batch vectors. It then computes, for each
    (input, outpt) pair, the estimated score bias_pred + word_pred (+ cvrt_pred).

    Args:
      input_indxs: int32 or int64 tensor for input labels
      outpt_indxs: int32 or int64 tensor for outpt labels
      scores: float32 tensor for co-occurrence score
      weights: float32 tensor for loss weights

    Returns:
      loss: a scalar tensor giving the loss from the batch. This is
        sum(weights * (estimate - scores)^2) / 2, or the w2v loss when
        self._use_w2v is set.
    """
    # Initialize input/outpt word (node) parameters
    self._default_scope = tf.get_variable_scope()
    init_width = 0.5 / (self._vector_size + self._covariate_size)
    self._word['input'] = self._weight_initializer('word_input', init_width,
                                                   self._vocab_size,
                                                   self._vector_size)
    self._word['outpt'] = self._weight_initializer('word_outpt', init_width,
                                                   self._vocab_size,
                                                   self._vector_size)

    # Initialize input/outpt bias parameters
    self._bias['input'] = self._weight_initializer('bias_input', init_width,
                                                   self._vocab_size, 1)
    self._bias['outpt'] = self._weight_initializer('bias_outpt', init_width,
                                                   self._vocab_size, 1)

    if self._covariate_size > 0:
      # Initialize input/outpt cvrt transformation parameters
      self._cvrt_transformation['input'] = self._weight_initializer(
          'cvrt_input', init_width, self._covariate_data.shape[1],
          self._covariate_size)
      self._cvrt_transformation['outpt'] = self._weight_initializer(
          'cvrt_outpt', init_width, self._covariate_data.shape[1],
          self._covariate_size)

      # Project the covariate data with the transformation parameters
      self._cvrt['input'] = tf.matmul(self._covariate_data_tensor,
                                      self._cvrt_transformation['input'])
      self._cvrt['outpt'] = tf.matmul(self._covariate_data_tensor,
                                      self._cvrt_transformation['outpt'])

      if self._use_monet:
        # Left singular vectors of the summed covariate embeddings
        # (tf.linalg.svd returns s, u, v).
        _, self._u, _ = tf.linalg.svd(self._cvrt['input'] + self._cvrt['outpt'])

        # Subtract db_level times the component of the base word vectors that
        # lies in the span of u. The result is wrapped in stop_gradient;
        # _monet_train_op reroutes its gradients to the base variables.
        self._projected_word_input = tf.stop_gradient(
            self._word['input'] - self._db_level * tf.matmul(
                self._u, tf.matmul(tf.transpose(self._u), self._word['input'])))
        self._projected_word_outpt = tf.stop_gradient(
            self._word['outpt'] - self._db_level * tf.matmul(
                self._u, tf.matmul(tf.transpose(self._u), self._word['outpt'])))

    # Get loss input word vectors
    if self._use_monet:
      self._input_word_vecs = tf.nn.embedding_lookup(self._projected_word_input,
                                                     input_indxs)
      self._outpt_word_vecs = tf.nn.embedding_lookup(self._projected_word_outpt,
                                                     outpt_indxs)
    else:
      self._input_word_vecs = tf.nn.embedding_lookup(self._word['input'],
                                                     input_indxs)
      self._outpt_word_vecs = tf.nn.embedding_lookup(self._word['outpt'],
                                                     outpt_indxs)

    # Get loss input bias vectors
    self._input_bias_vecs = tf.nn.embedding_lookup(self._bias['input'],
                                                   input_indxs)
    self._outpt_bias_vecs = tf.nn.embedding_lookup(self._bias['outpt'],
                                                   outpt_indxs)
    # Per-pair dot product of input and outpt word vectors (built once).
    self._word_pred = tf.reduce_sum(
        tf.multiply(self._input_word_vecs, self._outpt_word_vecs), axis=1)
    self._bias_pred = tf.reduce_sum(
        self._input_bias_vecs + self._outpt_bias_vecs, axis=1)
    estimated_score = self._bias_pred + self._word_pred

    # Add covariate terms
    if self._covariate_size > 0:
      self._input_cvrt_vecs = tf.nn.embedding_lookup(self._cvrt['input'],
                                                     input_indxs)
      self._outpt_cvrt_vecs = tf.nn.embedding_lookup(self._cvrt['outpt'],
                                                     outpt_indxs)
      self._cvrt_pred = tf.reduce_sum(
          tf.multiply(self._input_cvrt_vecs, self._outpt_cvrt_vecs), axis=1)
      estimated_score += self._cvrt_pred
    else:
      self._cvrt_pred = tf.constant(0.0)

    self._scores = scores
    self._est_score = estimated_score
    if self._use_w2v:
      loss = self._compute_w2v_loss(input_indxs)
    else:
      diff = estimated_score - scores
      self._diff = diff
      loss = tf.reduce_sum(tf.multiply(weights, tf.square(diff))) / 2
    return loss

  def _monet_train_op(self, optimizer, loss, global_step):
    """Registers the MONET training op in the graph.

    Args:
      optimizer: a tf optimizer object
      loss: the loss to optimize from a tf graph
      global_step: train step for the network

    Returns:
      a tf train op
    """
    # Compute gradients
    var_list = [
        self._cvrt_transformation['input'], self._cvrt_transformation['outpt'],
        self._bias['input'], self._bias['outpt'], self._projected_word_input,
        self._projected_word_outpt
    ]
    # Point the model word vector gradients to the base word vector gradients
    grads_and_vars = optimizer.compute_gradients(loss, var_list)
    grads_and_vars[-2] = (grads_and_vars[-2][0], self._word['input'])
    grads_and_vars[-1] = (grads_and_vars[-1][0], self._word['outpt'])
    self._grads_and_vars = grads_and_vars
    return optimizer.apply_gradients(grads_and_vars, global_step)

  def _register_adversary(self):
    """Build adversary part of graph.

    A Discriminator is trained to tell the summed (input + outpt) word vectors
    of adversarially "positive" tokens from "negative" ones. The embedding
    variables (every trainable variable outside the 'discriminator' scope) are
    trained to fool it.

    Args: (none)
    Returns:
      train_d: adversary trainer
      train_g: generator trainer
    """
    # Rank-1 index vectors of dynamic length. Note the trailing comma:
    # shape=(None) is just None, i.e. unknown rank.
    adv_positives = tf.placeholder(tf.int32, shape=(None,))
    adv_negatives = tf.placeholder(tf.int32, shape=(None,))
    self._adv_positives = adv_positives
    self._adv_negatives = adv_negatives

    d = Discriminator(self._vector_size, self._adv_dim)

    # Both discriminator calls see the same summed word vectors.
    combined_word_vecs = self._word['input'] + self._word['outpt']
    d_neg = d.net(
        tf.nn.embedding_lookup(combined_word_vecs, adv_negatives), reuse=False)
    self._d_neg = d_neg
    d_pos = d.net(
        tf.nn.embedding_lookup(combined_word_vecs, adv_positives), reuse=True)
    self._d_pos = d_pos

    # NOTE(review): tf.log of a saturated sigmoid output can reach -inf, and an
    # empty positive/negative batch makes reduce_mean NaN. Consider guarding.
    loss_d = tf.reduce_mean(tf.log(d_pos)) + tf.reduce_mean(tf.log(1 - d_neg))
    loss_g = tf.reduce_mean(tf.log(d_neg)) + tf.reduce_mean(tf.log(1 - d_pos))

    self._adv_loss_d = loss_d
    self._adv_loss_g = loss_g

    d_var_list = tf.get_collection(
        tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
    g_var_list = [
        var for var in tf.trainable_variables()
        if 'discriminator' not in var.name
    ]
    print(d_var_list)
    print(g_var_list)

    # Both objectives are maximized, hence the negation.
    train_d = tf.train.AdamOptimizer(self._adv_lr).minimize(
        -loss_d, var_list=d_var_list)
    train_g = tf.train.AdamOptimizer(self._adv_lr).minimize(
        -loss_g, var_list=g_var_list)
    return train_d, train_g

  def _build_graph(self):
    """Build the graph for the full model and initialize its variables.

    The variable initializer is executed with `.run()`, so a default tf session
    must be active when this is called.

    Args: (none)
    Returns: (none)
    """

    # Make placeholders and covariate data. The per-example inputs are rank-1
    # vectors of dynamic length. Note the trailing comma: shape=(None) would
    # just be shape=None, i.e. unknown rank.
    input_indxs = tf.placeholder(tf.int32, shape=(None,))
    outpt_indxs = tf.placeholder(tf.int32, shape=(None,))
    neg_samples = tf.placeholder(tf.int32, shape=(None, self._neg_samples))
    scores = tf.placeholder(tf.float32, shape=(None,))
    weight = tf.placeholder(tf.float32, shape=(None,))
    self._input_indxs = input_indxs
    self._outpt_indxs = outpt_indxs
    self._neg_samples_tensor = neg_samples
    self._scores = scores
    self._weight = weight
    if self._covariate_size > 0:
      covariate_data = tf.Variable(
          self._covariate_data, trainable=False, dtype=tf.float32)
      self._covariate_data_tensor = covariate_data

    # Initialize global step. It is a counter, not a model parameter, so keep
    # it out of tf.trainable_variables().
    global_step = tf.Variable(0, name='global_step', trainable=False)
    self._global_step = global_step

    # Make the feed-forward network. The graph-level seed must be set before
    # the random initializers are created.
    tf.random.set_random_seed(self._random_seed)
    loss = self._forward(input_indxs, outpt_indxs, scores, weight)
    self._loss = loss

    # Optimize the loss
    optimizer = tf.train.AdagradOptimizer(
        self._learning_rate, initial_accumulator_value=1.0)
    if self._use_monet:
      train = self._monet_train_op(optimizer, loss, global_step)
    else:
      train = optimizer.minimize(loss, global_step)
    self._train = train

    # Get adv trainers
    if self._adv_labels:
      self._train_d, self._train_g = self._register_adversary()

    # Initialize the variables
    tf.global_variables_initializer().run()
    self._saver = tf.train.Saver()

  def _save_weights(self, weight_filename, weights):
    """Saves weights using glove_util.KeyedVectors object.

    A KeyedVectors of width weights.shape[1] is filled with self._tokens and
    `weights` and written with its save_word2vec_format method. It presumably
    pairs self._tokens[i] with row i; confirm against glove_util.

    Args:
      weight_filename: filename to save to
      weights: 2-D array of weights to save
    Returns: (nothing)
    """
    embedding_obj = glove_util.KeyedVectors(weights.shape[1])
    embedding_obj.add(self._tokens, weights)
    embedding_obj.save_word2vec_format(weight_filename)

  def _extract_weight_matrix(self, model):
    return np.array([model[t] for t in self._tokens])

  def _print_extra_diagnostics(self, session):
    """diagnostic printing.

    Prints every METRICS_TO_PRINT statistic of each weight matrix: word and
    bias parameters, plus covariate and (for MONET) projection matrices when
    covariates are used.

    Args:
        session: a tensorflow session
    Returns: (none)
    """
    # Get data for printing. Word and covariate embeddings come back as
    # token-keyed dicts and are stacked in token order. The covariate
    # transformation matrices come back as plain arrays and are used as-is,
    # because indexing them by token would fail.
    weight_mats = {}
    for name, weights in self._get_model_return_dict(session).items():
      if weights is None:
        continue
      if isinstance(weights, dict):
        weight_mats[name] = self._extract_weight_matrix(weights)
      else:
        weight_mats[name] = np.asarray(weights)
    bias_input, bias_outpt = session.run(
        [self._bias['input'], self._bias['outpt']])

    labeled_mats = [
        ('word input', weight_mats['topo_input']),
        ('word outpt', weight_mats['topo_outpt']),
        ('bias input', bias_input),
        ('bias outpt', bias_outpt),
    ]
    if self._use_monet:
      projected_word_input, projected_word_outpt, u = session.run(
          [self._projected_word_input, self._projected_word_outpt, self._u])
    if self._covariate_size > 0:
      labeled_mats += [
          ('cvrt input', weight_mats['meta_input']),
          ('cvrt outpt', weight_mats['meta_outpt']),
          ('cvrt trans input', weight_mats['meta_trans_input']),
          ('cvrt trans outpt', weight_mats['meta_trans_outpt']),
      ]
      if self._use_monet:
        labeled_mats += [
            ('projected word input', projected_word_input),
            ('projected word outpt', projected_word_outpt),
            ('projected word final',
             projected_word_outpt + projected_word_input),
            ('u from svd', u),
        ]

    # Print avg, msq, max row avg, and max row msq of each weight matrix
    for metric in sorted(METRICS_TO_PRINT):
      metric_fn = METRICS_TO_PRINT[metric]
      for label, mat in labeled_mats:
        print('%s %s: %0.9f' % (label, metric, metric_fn(mat)))

  def _get_batch(self, cooccurrence_filename, cooccurrence_count, max_count, f):
    """Gets a batch to train on.

    Args:
      cooccurrence_filename: where cooccurrences are
      cooccurrence_count: where we are in the cooccurrence list
      max_count: maximum cooccurrence count
      f: potentially an open file object for the co-occurrences

    Returns:
      lots of stuff
    """
    batch_data = {
        'input_indxs': [],
        'outpt_indxs': [],
        'scores': [],
        'weight': []
    }
    if self._adv_labels:
      batch_data['positive_indxs'] = []
      batch_data['negative_indxs'] = []
    continue_training = True
    for _ in range(self._batch_size):
      # Get an example
      if cooccurrence_filename:
        last_data_read = f.read(self._cooccurrence_fmt_length)
        if not last_data_read:
          continue_training = False
        batch_example = self._struct_unpack(last_data_read)
      else:
        if cooccurrence_count == max_count:
          continue_training = False
          break
        batch_example = self._cooccurrences[cooccurrence_count]
        cooccurrence_count += 1

      # Store the example
      batch_data['input_indxs'].append(batch_example[0] - 1)
      batch_data['outpt_indxs'].append(batch_example[1] - 1)
      batch_data['scores'].append(np.log(batch_example[2]))
      batch_data['weight'].append(self._compute_loss_weight(batch_example[2]))
      if self._adv_labels:
        if self._adv_labels[self._tokens[batch_example[0] - 1]][0] == 1.0:
          batch_data['positive_indxs'].append(batch_example[0] - 1)
        else:
          batch_data['negative_indxs'].append(batch_example[0] - 1)
        if self._adv_labels[self._tokens[batch_example[1] - 1]][0] == 1.0:
          batch_data['positive_indxs'].append(batch_example[1] - 1)
        else:
          batch_data['negative_indxs'].append(batch_example[1] - 1)
    batch_data['neg_samples'] = np.random.randint(
        0,
        self._vocab_size,
        size=(len(batch_data['scores']), self._neg_samples),
        dtype=np.int32)
    batch_data['neg_samples'] = [list(v) for v in batch_data['neg_samples']]
    return batch_data, continue_training

  def _print_weights(self, session, batch_data):
    word_input, word_outpt, bias_input, bias_outpt = session.run([
        self._word['input'], self._word['outpt'], self._bias['input'],
        self._bias['outpt']
    ])
    print('W is:')
    print(word_input[batch_data['input_indxs'][-1], :3])
    print(word_outpt[batch_data['outpt_indxs'][-1], :3])
    print('b is:')
    print(bias_input[batch_data['input_indxs'][-1], :3])
    print(bias_outpt[batch_data['outpt_indxs'][-1], :3])

  def _console_print(self, session, return_dict, batch_data, print_extra,
                     print_every):
    """Console printer.

    Args:
      session: the tf session
      return_dict: variables from train step
      batch_data: batch from train step
      print_extra: whether to print extra diagnostics
      print_every: how often to print
    Returns: (nothing)
    """
    percent_done = 100.0 * self._total_examples_trained / (
        self._num_records * self._iters)
    if print_every > 0 and return_dict['step'] % print_every == 0:
      print('---- iter %d, updates [%d, %d], last pair (%d, %d, %0.9f):' %
            ((self._iter, self._total_examples_trained -
              len(batch_data['scores']), self._total_examples_trained - 1) +
             (batch_data['input_indxs'][-1], batch_data['outpt_indxs'][-1],
              batch_data['scores'][-1])))
      print('---- %0.4f%% done: avg_cost %0.5f, sum_cost %0.5f' %
            (percent_done, self._sum_cost /
             (self._total_examples_trained), self._sum_cost))
      if np.isnan(np.sum(return_dict['est_score'])):
        print(return_dict['est_score'])
        print(return_dict['diff'])
        print(return_dict['loss'])
      print('---- est_score: %0.5f, diff: %0.5f, loss: %0.5f' %
            (np.sum(return_dict['est_score']), np.sum(
                return_dict['diff']), np.sum(return_dict['loss'])))
      print('-------- will do %d iters' % self._iters)
      if self._print_weight_diagnostics:
        self._print_weights(session, batch_data)
      if self._adv_labels:
        print('----adv_loss_g/d: %0.5f/%0.5f' %
              (return_dict['adv_loss_g'], return_dict['adv_loss_d']))
      if print_extra:
        print('word_pred: %0.9f, bias_pred: %0.9f, cvrt_pred: %0.9f' %
              (np.sum(return_dict['word_pred']), np.sum(
                  return_dict['bias_pred']), np.sum(return_dict['cvrt_pred'])))
        self._print_extra_diagnostics(session)
      print('===============================================')

  def _train_model_thread(self,
                          session,
                          cooccurrence_filename,
                          checkpoint_every,
                          checkpoint_dir,
                          print_extra_diagnostics=False,
                          print_every=-1,
                          kill_after=-1):
    """Trains glove model and saves word vectors.

    Args:
      session: a tensorflow session
      cooccurrence_filename: location of binary cooccurrences
      checkpoint_every: how often to checkpoint (counted in batches)
      checkpoint_dir: directory to save checkpoints in
      print_extra_diagnostics: whether to show extra diagnostics
      print_every: update the console every this number of steps
      kill_after: stop each iteration after this many updates
    Returns: (nothing)
    """
    cooccurrence_count = 0
    max_count = 0
    f = None
    if cooccurrence_filename:
      f = open(cooccurrence_filename)
    else:
      max_count = len(self._cooccurrences)
    continue_training = True
    while continue_training:
      batch_data, continue_training = self._get_batch(cooccurrence_filename,
                                                      cooccurrence_count,
                                                      max_count, f)
      len_batch = len(batch_data['scores'])
      if len_batch > 0:
        # Train on the batch
        training_dict = {
            self._input_indxs: batch_data['input_indxs'],
            self._outpt_indxs: batch_data['outpt_indxs'],
            self._neg_samples_tensor: batch_data['neg_samples'],
            self._scores: batch_data['scores'],
            self._weight: batch_data['weight']
        }
        if self._adv_labels:
          training_dict.update({
              self._adv_positives: batch_data['positive_indxs'],
              self._adv_negatives: batch_data['negative_indxs']
          })
        vars_to_get = {
            'train': self._train,
            'step': self._global_step,
            'diff': self._diff,
            'loss': self._loss,
            'est_score': self._est_score,
            'word_pred': self._word_pred,
            'bias_pred': self._bias_pred,
            'cvrt_pred': self._cvrt_pred
        }
        if self._adv_labels:
          vars_to_get.update({
              'adv_loss_g': self._adv_loss_g,
              'adv_loss_d': self._adv_loss_d,
              'train_d': self._train_d,
              'train_g': self._train_g
          })
        if self._covariate_size > 0:
          vars_to_get.update({'cvrt_pred': self._cvrt_pred})
        return_dict = session.run(vars_to_get, feed_dict=training_dict)
        self._sum_cost += return_dict['loss']
        if self._adv_labels:
          self._sum_adv_cost_g += return_dict['adv_loss_g']
          self._sum_adv_cost_d += return_dict['adv_loss_d']
        cooccurrence_count += len_batch
        self._total_examples_trained += len_batch

        # Report to console
        self._console_print(session, return_dict, batch_data,
                            print_extra_diagnostics, print_every)

      if kill_after > -1 and return_dict['step'] >= kill_after:
        return
      if (checkpoint_dir and checkpoint_every >= 0 and
          return_dict['step'] % checkpoint_every == 0):
        print('>>>>>>>>>>>>>>>checkpointing<<<<<<<<<<<<<<<<<<<<')
        checkpoint_prefix = os.path.join(checkpoint_dir,
                                         'chkpnt%d' % return_dict['step'])
        _ = self._get_model_return_dict(session, checkpoint_prefix,
                                        checkpoint_prefix)

    # Close the file, if used
    if cooccurrence_filename:
      f.close()

  def _make_keyed_weights(self, weights):
    """Makes key, weight vector representation of matrix with model tokens.

    Args:
      weights: weights indexed by vocab_index_lookup

    Returns:
      weight vector dict keyed by token
    """
    return dict(
        zip(self._tokens,
            [weights[self._vocab_index_lookup[t]] for t in self._tokens]))

  def _get_model_return_dict(self,
                             session,
                             output_prefix=None,
                             covariate_weight_output=None):
    """Gets a model weight dictionary for main training function.

    Args:
      session: a tensorflow session
      output_prefix: prefix for all parameters except covariate transformation
      covariate_weight_output: prefix for covariate transformation
    Returns:
      return_dict: dict with weight dictionaries.
    """
    return_dict = {
        'topo_input': None,
        'topo_outpt': None,
        'meta_input': None,
        'meta_outpt': None,
        'meta_trans_input': None,
        'meta_trans_outpt': None
    }

    # Get word embeddings
    topo_input, topo_outpt = session.run(
        [self._word['input'], self._word['outpt']])
    return_dict['topo_input'] = self._make_keyed_weights(topo_input)
    return_dict['topo_outpt'] = self._make_keyed_weights(topo_outpt)

    if self._covariate_size > 0:
      # Get covariate embeddings
      cvrt_input, cvrt_outpt = session.run(
          [self._cvrt['input'], self._cvrt['outpt']])
      return_dict['meta_input'] = self._make_keyed_weights(cvrt_input)
      return_dict['meta_outpt'] = self._make_keyed_weights(cvrt_outpt)
      if output_prefix is not None:
        self._save_weights(output_prefix + '_cvrtvecs.txt',
                           cvrt_input + cvrt_outpt)

      if self._use_monet:
        # Project base word vectors one more time, store in word embeds
        u = session.run(self._u)
        topo_input -= self._db_level * np.matmul(
            u, np.matmul(np.transpose(u), topo_input))
        topo_outpt -= self._db_level * np.matmul(
            u, np.matmul(np.transpose(u), topo_outpt))
        return_dict['topo_input'] = self._make_keyed_weights(topo_input)
        return_dict['topo_outpt'] = self._make_keyed_weights(topo_outpt)
        if output_prefix is not None:
          self._save_weights(
              output_prefix + '.txt',
              np.concatenate([topo_input + topo_outpt, cvrt_input + cvrt_outpt],
                             axis=1))

      # Get covariate transformation
      (return_dict['meta_trans_input'],
       return_dict['meta_trans_outpt']) = session.run([
           self._cvrt_transformation['input'],
           self._cvrt_transformation['outpt']
       ])
      if output_prefix is not None:
        np.savetxt(
            covariate_weight_output + '_cvrt_projections.txt',
            np.concatenate([
                return_dict['meta_trans_input'], return_dict['meta_trans_outpt']
            ],
                           axis=1))
    # Save topo embeddings
    wordvec_tag = 'wordvecs' if self._covariate_size > 0 else ''
    if output_prefix is not None:
      self._save_weights(output_prefix + '_%s.txt' % wordvec_tag,
                         topo_input + topo_outpt)
    return return_dict

  def train_model(self,
                  session,
                  iters,
                  alpha,
                  xmax,
                  eta,
                  regress_out_covariates,
                  covariate_weight_output=None,
                  output=None,
                  cooccurrence_filename=None,
                  print_every=-1,
                  print_extra_diagnostics=False,
                  print_weight_diagnostics=False,
                  checkpoint_every=-1,
                  checkpoint_dir=None,
                  db_level=1.0,
                  batch_size=1,
                  kill_after=-1,
                  init_weight_dir=None,
                  use_w2v=False,
                  neg_samples=5,
                  w2v_logit_max=10.0,
                  w2v_neg_sample_mean=False,
                  adv_lam=0.2,
                  adv_labels=None,
                  adv_lr=0.05,
                  adv_dim=None):
    """Trains glove model and saves word vectors.

    Stores the training hyper-parameters on `self`, builds the graph via
    `self._build_graph()`, runs `iters` passes of `self._train_model_thread`
    (printing the running average cost after each pass) and finally collects
    the trained weights via `self._get_model_return_dict`.

    Args:
      session: a tensorflow session
      iters: number of iterations through the corpus
      alpha: weighted diff scaling power
      xmax: weighted diff scaling threshold
      eta: initial learning rate
      regress_out_covariates: whether to regress out word vecs with cvrt vecs
      covariate_weight_output: location to save covariate transformation at
      output: output filename for word vectors.
      cooccurrence_filename: location of binary cooccurrences
      print_every: update the console every this number of steps
      print_extra_diagnostics: whether to show extra covariate diagnostics
      print_weight_diagnostics: whether to print weight/update values
      checkpoint_every: number of steps to checkpoint after
      checkpoint_dir: where to put checkpoint files
      db_level: ("debias level") - a double between 0.0 and 1.0 inclusive giving
        the strength of the debiasing. 0.0 is no debiasing, 1.0 is full.
      batch_size: number of cooccurrences to train at once
      kill_after: kill each iteration after this many updates
      init_weight_dir: if specified, will load weights from directory
      use_w2v: uses word2vec-like loss, adapted for co-occurrence counts
      neg_samples: number of negative samples for word2vec
      w2v_logit_max: logits in w2v model are capped (in absolute value) at this
      w2v_neg_sample_mean: negative sample loss is averaged, if True
      adv_lam: tuning parameter for adversarial loss. For now, this assumes you
        have input metadata and that the metadata is binary and one-dimensional.
        The adversary uses a length(adv_dim)-layer MLP with leaky ReLU
        activations. The loss is softmax cross-entropy.
      adv_labels: a {token, v} dict where v is a 2-length list of one-hot floats
      adv_lr: learning rate for all adversarial train ops
      adv_dim: dimension of hidden layer of MLP adversary

    Returns:
      An empty dict if neither `cooccurrence_filename` nor in-memory
      cooccurrences (`self._cooccurrences`) are available. Otherwise, the dict
      returned by `self._get_model_return_dict(session, output,
      covariate_weight_output)`, which holds the trained weights and also
      writes them to files when `output` is not None.

    Raises:
      ValueError: if `db_level` is not within [0.0, 1.0].
    """
    # Explicit check rather than `assert`: asserts are stripped under -O.
    if cooccurrence_filename is None and self._cooccurrences is None:
      print(('Error: must specify a cooccurrence filename if model object '
             'was not given random walks.'))
      return {}

    # Set training parameters and other trackers
    self._alpha = alpha
    self._eta = eta
    self._xmax = xmax
    self._batch_size = batch_size
    self._learning_rate = eta
    self._iters = iters
    self._total_examples_trained = 0
    self._use_w2v = use_w2v
    self._neg_samples = neg_samples
    self._print_weight_diagnostics = print_weight_diagnostics
    self._w2v_logit_max = w2v_logit_max
    self._w2v_neg_sample_scale = 1.0
    if w2v_neg_sample_mean:
      # NOTE(review): presumably used as a divisor of the negative-sample loss
      # when averaging is requested; confirm in _build_graph.
      self._w2v_neg_sample_scale *= self._neg_samples
    if not 0.0 <= db_level <= 1.0:
      raise ValueError(
          'db_level must be between 0.0 and 1.0 inclusive, got %r' % db_level)
    self._db_level = db_level
    self._adv_dim = adv_dim
    self._adv_lam = adv_lam
    self._adv_labels = adv_labels
    self._adv_lr = adv_lr
    if cooccurrence_filename:
      # The file holds fixed-size binary records, so the record count is an
      # integer; floor division keeps it an int under Python 3.
      self._num_records = (
          os.path.getsize(cooccurrence_filename) //
          self._cooccurrence_fmt_length)
    else:
      self._num_records = len(self._cooccurrences)
    self._use_monet = regress_out_covariates
    self._init_weight_dir = init_weight_dir
    # Build the graph
    print('building graph...')
    self._build_graph()
    print('training...')
    # Iterate through cooccurrences multiple times
    for i in range(iters):
      self._iter = i + 1
      self._train_model_thread(session, cooccurrence_filename, checkpoint_every,
                               checkpoint_dir, print_extra_diagnostics,
                               print_every, kill_after)
      # Guard the averages: if no examples were trained (empty input, or
      # kill_after cut the pass short) dividing by zero would crash here.
      n_trained = max(self._total_examples_trained, 1)
      print('end of iter %d, avg_cost %0.5f' %
            (self._iter, self._sum_cost / n_trained))
      if self._adv_labels:
        print('---- avg_adv_cost_d %0.5f' %
              (self._sum_adv_cost_d / n_trained))
        print('---- avg_adv_cost_g %0.5f' %
              (self._sum_adv_cost_g / n_trained))
    return self._get_model_return_dict(session, output, covariate_weight_output)
